"""train process"""

import argparse
import os
import random
import time

import numpy as np
import torch
from torch import nn
from torch.utils import data as Data

from src import (
    calculate_l2_error,
    create_datasets,
    load_yaml_config,
    pad_collate,
    tokmak_model,
)

# Seed every random-number generator used by this script (Python, NumPy,
# PyTorch) with one fixed value so training runs are repeatable.
seed = 123456
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
del _seed_rng


def _str2bool(value):
    """Convert a command-line string such as "true", "False" or "1" into a bool.

    ``argparse`` with ``type=bool`` evaluates ``bool("False")``, which is ``True``
    for any non-empty string, so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse command-line arguments for the training script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``config_file_path``, ``device_target``,
        ``device_id``, ``mode``, ``save_graphs`` (a real bool) and
        ``save_graphs_path``.
    """
    parser = argparse.ArgumentParser(description="tokmak train")
    parser.add_argument(
        "--config_file_path", type=str, default="./configs/tokamak.yaml"
    )
    parser.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        choices=["GPU", "Ascend"],
        help="The target device to run, support 'Ascend', 'GPU'",
    )
    parser.add_argument(
        "--device_id", type=int, default=0, help="ID of the target device"
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="GRAPH",
        choices=["GRAPH", "PYNATIVE"],
        help="Running in GRAPH_MODE OR PYNATIVE_MODE",
    )
    parser.add_argument(
        "--save_graphs",
        # _str2bool instead of bool: bool("False") would be True.
        type=_str2bool,
        default=False,
        choices=[True, False],
        help="Whether to save intermediate compilation graphs",
    )
    parser.add_argument("--save_graphs_path", type=str, default="./graphs")

    return parser.parse_args(argv)


def _split_batch(X, valid_len, valid_channels, out_channels):
    """Move one collated batch to the GPU and build encoder/decoder inputs and targets.

    The last ``out_channels`` entries along the final axis of ``X`` are the targets;
    the remaining leading entries are the encoder inputs.

    Args:
        X: batch tensor indexed as ``X[:, :, channel]``; presumably
            (batch, time, channels) as produced by ``pad_collate`` -- TODO confirm.
        valid_len: per-sample valid-length tensor, passed through to the loss.
        valid_channels: channel-validity tensor indexed as ``[:, channel]``.
        out_channels: number of trailing channels treated as prediction targets.

    Returns:
        ``(enc_inputs, dec_inputs, label, valid_len, valid_channels)``, all on the
        GPU, with ``valid_channels`` restricted to the target channels.
    """
    X = X.cuda()
    valid_len = valid_len.cuda()
    valid_channels = valid_channels.cuda()[:, -out_channels:]
    enc_inputs = X[:, :, :-out_channels]
    label = X[:, :, -out_channels:]
    # Decoder inputs are the labels shifted by one position along dim 1, with an
    # all-zero frame at the front (presumably teacher forcing).
    start_frame = torch.zeros(
        label.shape[0], 1, label.shape[-1], device=X.device, dtype=X.dtype
    )
    dec_inputs = torch.cat((start_frame, label), 1)[:, :-1, :]
    return enc_inputs, dec_inputs, label, valid_len, valid_channels


def train_process(
    model,
    cur_steps,
    optimizer,
    scheduler,
    train_loader,
    out_channels,
    lossfn_config,
):
    """Run one training epoch over ``train_loader``.

    Each batch does: zero grads, compute the loss via ``calculate_l2_error``,
    backpropagate, step the optimizer, then step the LR scheduler (per batch).

    Args:
        model: network passed to ``calculate_l2_error``; switched to train mode.
        cur_steps: running counter forwarded to the loss function. It is advanced
            by the number of samples in each batch, not by optimizer steps.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch.
        train_loader: yields ``(X, valid_len, valid_channels, h5_files)`` tuples.
        out_channels: number of trailing channels of ``X`` used as targets.
        lossfn_config: dict of loss options forwarded to ``calculate_l2_error``.

    Returns:
        ``(epoch_train_loss, cur_steps)``: the sum of per-batch losses over the
        epoch and the updated counter.
    """
    model.train()
    epoch_train_loss = 0.0
    for X, valid_len, valid_channels, _h5_files in train_loader:
        optimizer.zero_grad(set_to_none=True)
        enc_inputs, dec_inputs, label, valid_len, valid_channels = _split_batch(
            X, valid_len, valid_channels, out_channels
        )
        cur_steps += X.shape[0]
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        l2_loss.backward()
        optimizer.step()
        scheduler.step()
        epoch_train_loss += l2_loss.item()
    return epoch_train_loss, cur_steps


@torch.no_grad()
def eval_process(model, val_loader, out_channels, lossfn_config):
    """Evaluate ``model`` over ``val_loader`` without gradient tracking.

    Args:
        model: network passed to ``calculate_l2_error``; switched to eval mode.
        val_loader: yields ``(X, valid_len, valid_channels, h5_files)`` tuples.
        out_channels: number of trailing channels of ``X`` used as targets.
        lossfn_config: dict of loss options forwarded to ``calculate_l2_error``.

    Returns:
        The sum of per-batch losses over the whole loader.
    """
    model.eval()
    # The loss function takes a step counter; it is meaningless during
    # evaluation, so a constant 0 is passed for every batch.
    cur_steps = 0
    epoch_eval_loss = 0.0
    for X, valid_len, valid_channels, _h5_files in val_loader:
        enc_inputs, dec_inputs, label, valid_len, valid_channels = _split_batch(
            X, valid_len, valid_channels, out_channels
        )
        l2_loss = calculate_l2_error(
            model,
            enc_inputs,
            dec_inputs,
            cur_steps,
            label,
            valid_len,
            valid_channels,
            lossfn_config,
        )
        epoch_eval_loss += l2_loss.item()
    return epoch_eval_loss


def _estimated_steps_per_epoch(loader, num_workers):
    """Return an upper-bound estimate of how many batches ``loader`` yields per epoch.

    Starts from ``len(loader)`` and adds one batch per worker beyond the first.
    NOTE(review): the extra term presumably accounts for iterable datasets sharded
    across workers, where each worker may emit its own partial final batch --
    confirm against ``create_datasets``. With 0 or 1 workers nothing is added.
    Earlier code added ``num_workers - 1`` unconditionally, which undercounted by
    one when ``num_workers == 0`` and made OneCycleLR raise in the last epoch.
    """
    return len(loader) + max(num_workers - 1, 0)


def _save_checkpoint(model, epoch, ckpt_dir, train_loss, val_loss):
    """Save the epoch number and model state dict as ``<epoch>-<train>-<val>.pt`` in ``ckpt_dir``."""
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(
        {"epoch": epoch, "model_state": model.state_dict()},
        os.path.join(ckpt_dir, f"{epoch}-{train_loss:.5f}-{val_loss:.5f}.pt"),
    )


def train(input_args=None):
    """Train the model, periodically evaluating it and optionally saving checkpoints.

    Args:
        input_args: parsed CLI namespace providing ``config_file_path``. Defaults
            to the module-level ``args`` assigned by the ``__main__`` entry point.
    """
    if input_args is None:
        input_args = args  # module-level global assigned under ``__main__``

    # load configurations
    config = load_yaml_config(input_args.config_file_path)
    train_paras = config["data"]["train"]
    val_paras = config["data"]["validation"]
    loss_paras = config["loss"]["train"]
    summary_paras = config["summary"]
    # For future improvements, please do not rewrite this.
    lossfn_config = {}
    lossfn_config["limiter_steps"] = loss_paras["limiter_steps"]

    # create datasets and loaders
    train_set, val_set = create_datasets(config, is_debug=train_paras["is_debug"])
    train_loader = Data.DataLoader(
        train_set,
        batch_size=train_paras["batch_size"],
        num_workers=train_paras["num_workers"],
        collate_fn=pad_collate,
    )
    train_steps_per_epoch = _estimated_steps_per_epoch(
        train_loader, train_paras["num_workers"]
    )

    val_loader = Data.DataLoader(
        val_set,
        batch_size=val_paras["batch_size"],
        num_workers=val_paras["num_workers"],
        collate_fn=pad_collate,
    )
    # Use the validation loader's own worker count (was train_paras by mistake).
    val_steps_per_epoch = _estimated_steps_per_epoch(
        val_loader, val_paras["num_workers"]
    )

    # define model
    model_params = config["model"]
    optim_params = config["optimizer"]
    model = tokmak_model(
        in_channels=model_params["in_channels"],
        hidden_size=model_params["hidden_size"],
        num_layers=model_params["num_layers"],
        dropout_rate=model_params["dropout_rate"],
        out_channels=model_params["out_channels"],
        noise_ratio=model_params["noise_ratio"],
    )
    model.cuda()
    out_channels = model_params["out_channels"]

    # define optimizer & scheduler; float() because YAML may load values such
    # as "1e-4" as strings.
    lr = float(optim_params["lr"])
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        weight_decay=float(optim_params["weight_decay"]),
    )
    epochs = train_paras["epochs"]
    # OneCycleLR raises if stepped more than epochs * steps_per_epoch times, so
    # steps_per_epoch must not undercount the batches train_process iterates.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        steps_per_epoch=train_steps_per_epoch,
        epochs=epochs,
    )

    cur_steps = 0
    for epoch in range(1, epochs + 1):
        # train
        epoch_start = time.time()
        epoch_train_loss, cur_steps = train_process(
            model,
            cur_steps,
            optimizer,
            scheduler,
            train_loader,
            out_channels,
            lossfn_config,
        )
        epoch_seconds = time.time() - epoch_start
        step_ms = epoch_seconds * 1000 / train_steps_per_epoch
        # Averaged over the *estimated* step count, so it can differ slightly
        # from the true mean batch loss.
        step_train_loss = epoch_train_loss / train_steps_per_epoch
        print(
            f"epoch: {epoch} train loss: {step_train_loss} "
            f"epoch time: {epoch_seconds:5.3f}s step time: {step_ms:5.3f} ms"
        )

        if epoch % summary_paras["eval_interval_epochs"] == 0:
            eval_start = time.time()
            epoch_val_loss = eval_process(
                model, val_loader, out_channels, lossfn_config
            )
            step_val_loss = epoch_val_loss / val_steps_per_epoch
            eval_seconds = time.time() - eval_start
            step_val_ms = eval_seconds * 1000 / val_steps_per_epoch
            print(
                f"epoch: {epoch} val loss: {step_val_loss} "
                f"evaluation time: {eval_seconds:5.3f}s step time: {step_val_ms:5.3f}ms"
            )
            if summary_paras["save_ckpt"]:
                _save_checkpoint(
                    model,
                    epoch,
                    summary_paras["ckpt_dir"],
                    step_train_loss,
                    step_val_loss,
                )


if __name__ == "__main__":
    from src import log_config

    log_config("./logs", "tokmak")
    print("pid:", os.getpid())
    args = parse_args()
    print(f"Running in {args.mode.upper()} mode, using device id: {args.device_id}.")
    # The training code moves tensors with bare ``.cuda()``, which targets the
    # current CUDA device; select it here so --device_id actually takes effect.
    # (--mode, --device_target and --save_graphs are not otherwise used here.)
    if torch.cuda.is_available():
        torch.cuda.set_device(args.device_id)
    train()
